package org.example.nebula.basic

import com.facebook.thrift.protocol.TCompactProtocol
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.example.graphx.local.SSSPExample
import org.example.utils.ArgsUtil

import scala.collection.mutable

/**
 * Demonstrates three ways of reading command-line arguments in a Spark job:
 * positional indexing, hand-rolled `key=value` parsing, and the project's `ArgsUtil.parse` helper.
 */
object AcceptParam {

  /**
   * Entry point.
   *
   * @param args command-line arguments. Method 1 reads the first three positionally;
   *             methods 2 and 3 interpret them as `key=value` pairs
   *             (e.g. `numIter=10 input=data/x.libsvm lr=0.05`).
   */
  def main(args: Array[String]): Unit = {

    // Log under this object's own name (the original borrowed SSSPExample's by mistake).
    val logger = Logger.getLogger(getClass.getName)

    // Create a SparkSession; Kryo is configured with TCompactProtocol registered.
    val sparkConf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
    val spark = SparkSession
      .builder()
      .master("local")
      .config(sparkConf)
      .getOrCreate()

    // Always release the session, even if argument parsing below throws.
    try {
      // Method 1: read the arguments positionally (requires at least three of them).
      args match {
        case Array(algoName, dataPath, resultPath, _*) =>
          println(s"$algoName\t$dataPath\t$resultPath")
        case _ =>
          logger.warn(s"Method 1 expects at least 3 positional arguments but got ${args.length}; skipping")
      }

      // NOTE(review): ERROR level is presumably used only so the separator stays visible
      // when the log level is raised to WARN; consider println or INFO instead.
      logger.error("-------------------------------------------")

      // Method 2: parse `key=value` pairs by hand. Split on the first '=' only so values
      // may themselves contain '=', and skip (instead of crashing on) malformed entries.
      // Later duplicates of a key overwrite earlier ones.
      val argMap: Map[String, String] = args.flatMap { arg =>
        arg.split("=", 2) match {
          case Array(key, value) => Some(key -> value)
          case _ =>
            logger.warn(s"Ignoring argument without '=': $arg")
            None
        }
      }.toMap

      argMap.foreach { case (key, value) => println(s"$key\t$value") }

      logger.error("-------------------------------------------")

      // Method 3: delegate parsing to the project helper.
      // NOTE(review): assumes ArgsUtil.parse returns a String -> String map; `toInt` and
      // `toDouble` throw NumberFormatException on non-numeric values.
      val params = ArgsUtil.parse(args, "=")
      val numIter = params.getOrElse("numIter", "5").toInt
      val input = params.getOrElse("input", "data/a9a/a9a_123d_train.libsvm")
      val lr = params.getOrElse("lr", "0.1").toDouble
      println(s"$numIter\t$input\t$lr")
    } finally {
      spark.stop()
    }
  }
}
